{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Copyright (c) Microsoft Corporation. All rights reserved.*  \n",
    "\n",
    "*Licensed under the MIT License.*"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Natural Language Inference on MultiNLI Dataset using BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Before You Start\n",
    "\n",
    "The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n",
    "> **Tip:** To run through the notebook quickly, set the **`QUICK_RUN`** flag in the cell below to **`True`**. The notebook then uses a small subset of the data and fewer epochs. \n",
    "\n",
    "The table below lists reference running times for different machine configurations.  \n",
    "\n",
    "|QUICK_RUN|Machine configuration|Running time|\n",
    "|:---------|:----------------------|:------------|\n",
    "|True|4 **CPU**s, 14GB memory| ~ 15 minutes|\n",
    "|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 5 minutes|\n",
    "|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 10.5 hours|\n",
    "|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ 2.5 hours|\n",
    "\n",
    "If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that doing so may reduce model performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set QUICK_RUN = True to run the notebook on a small subset of the data and for fewer epochs.\n",
    "QUICK_RUN = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "In this notebook, we demonstrate using [BERT](https://arxiv.org/abs/1810.04805) to perform Natural Language Inference (NLI). We use the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset, and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral.   \n",
    "The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens of the two sentences in each pair and separates the sentences with the [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.\n",
    "<img src=\"https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import random\n",
    "import numpy as np\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "import torch\n",
    "\n",
    "nlp_path = os.path.abspath('../../')\n",
    "if nlp_path not in sys.path:\n",
    "    sys.path.insert(0, nlp_path)\n",
    "\n",
    "from utils_nlp.models.bert.sequence_classification import BERTSequenceClassifier\n",
    "from utils_nlp.models.bert.common import Language, Tokenizer\n",
    "from utils_nlp.dataset.multinli import load_pandas_df\n",
    "from utils_nlp.common.timer import Timer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configurations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# Fractions (0 to 1) of each dataset to use; passed to DataFrame.sample(frac=...) below.\n",
    "TRAIN_DATA_USED_PERCENT = 1\n",
    "DEV_DATA_USED_PERCENT = 1\n",
    "NUM_EPOCHS = 2\n",
    "\n",
    "if QUICK_RUN:\n",
    "    TRAIN_DATA_USED_PERCENT = 0.001\n",
    "    DEV_DATA_USED_PERCENT = 0.01\n",
    "    NUM_EPOCHS = 1\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    BATCH_SIZE = 32\n",
    "else:\n",
    "    BATCH_SIZE = 16\n",
    "\n",
    "# set random seeds\n",
    "RANDOM_SEED = 42\n",
    "random.seed(RANDOM_SEED)\n",
    "np.random.seed(RANDOM_SEED)\n",
    "torch.manual_seed(RANDOM_SEED)\n",
    "num_cuda_devices = torch.cuda.device_count()\n",
    "if num_cuda_devices > 1:\n",
    "    torch.cuda.manual_seed_all(RANDOM_SEED)\n",
    "\n",
    "# model configurations\n",
    "LANGUAGE = Language.ENGLISH\n",
    "TO_LOWER = True\n",
    "MAX_SEQ_LENGTH = 128\n",
    "\n",
    "# optimizer configurations\n",
    "LEARNING_RATE = 5e-5\n",
    "WARMUP_PROPORTION = 0.1\n",
    "\n",
    "# data configurations\n",
    "TEXT_COL = \"text\"\n",
    "LABEL_COL = \"gold_label\"\n",
    "\n",
    "CACHE_DIR = \"./temp\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "The MultiNLI dataset comes with three subsets: train, dev_matched, and dev_mismatched. The dev_matched subset is drawn from the same genres as the train subset, while the dev_mismatched subset is drawn from genres not seen in the training data.   \n",
    "The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"train\")\n",
    "dev_df_matched = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev_matched\")\n",
    "dev_df_mismatched = load_pandas_df(local_cache_path=CACHE_DIR, file_split=\"dev_mismatched\")"
   ]
  },
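  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Drop dev examples whose gold label is '-' (no annotator consensus).\n",
    "# Copy the filtered frames so that columns can be added later without a SettingWithCopyWarning.\n",
    "dev_df_matched = dev_df_matched.loc[dev_df_matched['gold_label'] != '-'].copy()\n",
    "dev_df_mismatched = dev_df_mismatched.loc[dev_df_mismatched['gold_label'] != '-'].copy()"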
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training dataset size: 392702\n",
      "Development (matched) dataset size: 9815\n",
      "Development (mismatched) dataset size: 9832\n",
      "\n",
      "   gold_label                                          sentence1  \\\n",
      "0     neutral  Conceptually cream skimming has two basic dime...   \n",
      "1  entailment  you know during the season and i guess at at y...   \n",
      "2  entailment  One of our number will carry out your instruct...   \n",
      "3  entailment  How do you know? All this is their information...   \n",
      "4     neutral  yeah i tell you what though if you go price so...   \n",
      "\n",
      "                                           sentence2  \n",
      "0  Product and geography are what make cream skim...  \n",
      "1  You lose the things to the following level if ...  \n",
      "2  A member of my team will execute your orders w...  \n",
      "3                  This information belongs to them.  \n",
      "4           The tennis shoes have a range of prices.  \n"
     ]
    }
   ],
   "source": [
    "print(\"Training dataset size: {}\".format(train_df.shape[0]))\n",
    "print(\"Development (matched) dataset size: {}\".format(dev_df_matched.shape[0]))\n",
    "print(\"Development (mismatched) dataset size: {}\".format(dev_df_mismatched.shape[0]))\n",
    "print()\n",
    "print(train_df[['gold_label', 'sentence1', 'sentence2']].head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Pair the first and second sentences of each example into a tuple to form the input text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>gold_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>(Conceptually cream skimming has two basic dim...</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>(you know during the season and i guess at at ...</td>\n",
       "      <td>entailment</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>(One of our number will carry out your instruc...</td>\n",
       "      <td>entailment</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>(How do you know? All this is their informatio...</td>\n",
       "      <td>entailment</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>(yeah i tell you what though if you go price s...</td>\n",
       "      <td>neutral</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                text  gold_label\n",
       "0  (Conceptually cream skimming has two basic dim...     neutral\n",
       "1  (you know during the season and i guess at at ...  entailment\n",
       "2  (One of our number will carry out your instruc...  entailment\n",
       "3  (How do you know? All this is their informatio...  entailment\n",
       "4  (yeah i tell you what though if you go price s...     neutral"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df[TEXT_COL] = list(zip(train_df['sentence1'], train_df['sentence2']))\n",
    "dev_df_matched[TEXT_COL] = list(zip(dev_df_matched['sentence1'], dev_df_matched['sentence2']))\n",
    "dev_df_mismatched[TEXT_COL] = list(zip(dev_df_mismatched['sentence1'], dev_df_mismatched['sentence2']))\n",
    "train_df[[TEXT_COL, LABEL_COL]].head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_df = train_df.sample(frac=TRAIN_DATA_USED_PERCENT).reset_index(drop=True)\n",
    "dev_df_matched = dev_df_matched.sample(frac=DEV_DATA_USED_PERCENT).reset_index(drop=True)\n",
    "dev_df_mismatched = dev_df_mismatched.sample(frac=DEV_DATA_USED_PERCENT).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tokenize and Preprocess\n",
    "Before training, we tokenize the sentence pairs and convert them to lists of tokens. The following steps instantiate a BERT tokenizer for the given language and tokenize the text of the training and development sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 392702/392702 [03:25<00:00, 1907.47it/s]\n",
      "100%|██████████| 9815/9815 [00:05<00:00, 1961.13it/s]\n",
      "100%|██████████| 9832/9832 [00:05<00:00, 1837.42it/s]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = Tokenizer(LANGUAGE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)\n",
    "\n",
    "train_tokens = tokenizer.tokenize(train_df[TEXT_COL])\n",
    "dev_matched_tokens = tokenizer.tokenize(dev_df_matched[TEXT_COL])\n",
    "dev_mismatched_tokens = tokenizer.tokenize(dev_df_mismatched[TEXT_COL])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In addition, we perform the following preprocessing steps in the cell below:\n",
    "\n",
    "* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary\n",
    "* Add the special tokens [CLS] and [SEP] to mark the start of the sequence and the end of each sentence\n",
    "* Pad or truncate the token lists to the specified max length\n",
    "* Return input mask lists that distinguish real tokens from padding\n",
    "* Return token type id lists that indicate which sentence the tokens belong to\n",
    "\n",
    "*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_token_ids, train_input_mask, train_token_type_ids = \\\n",
    "    tokenizer.preprocess_classification_tokens(train_tokens, max_len=MAX_SEQ_LENGTH)\n",
    "dev_matched_token_ids, dev_matched_input_mask, dev_matched_token_type_ids = \\\n",
    "    tokenizer.preprocess_classification_tokens(dev_matched_tokens, max_len=MAX_SEQ_LENGTH)\n",
    "dev_mismatched_token_ids, dev_mismatched_input_mask, dev_mismatched_token_type_ids = \\\n",
    "    tokenizer.preprocess_classification_tokens(dev_mismatched_tokens, max_len=MAX_SEQ_LENGTH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_encoder = LabelEncoder()\n",
    "train_labels = label_encoder.fit_transform(train_df[LABEL_COL])\n",
    "num_labels = len(np.unique(train_labels))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and Predict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "classifier = BERTSequenceClassifier(language=LANGUAGE,\n",
    "                                    num_labels=num_labels,\n",
    "                                    cache_dir=CACHE_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:   0%|          | 1/12272 [00:10<35:06:53, 10.30s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:1->1228/12272; average training loss:1.199178\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  10%|█         | 1229/12272 [07:20<1:03:16,  2.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:1229->2456/12272; average training loss:0.783637\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  20%|██        | 2457/12272 [14:28<55:44,  2.93it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:2457->3684/12272; average training loss:0.692243\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  30%|███       | 3685/12272 [21:37<48:36,  2.94it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:3685->4912/12272; average training loss:0.653206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  40%|████      | 4913/12272 [28:45<41:36,  2.95it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:4913->6140/12272; average training loss:0.625751\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  50%|█████     | 6141/12272 [35:54<34:44,  2.94it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:6141->7368/12272; average training loss:0.605123\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  60%|██████    | 7369/12272 [42:58<27:46,  2.94it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:7369->8596/12272; average training loss:0.590521\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  70%|███████   | 8597/12272 [50:07<20:52,  2.93it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:8597->9824/12272; average training loss:0.577829\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  80%|████████  | 9825/12272 [57:14<13:46,  2.96it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:9825->11052/12272; average training loss:0.566418\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  90%|█████████ | 11053/12272 [1:04:20<06:53,  2.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:1/2; batch:11053->12272/12272; average training loss:0.556558\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration: 100%|██████████| 12272/12272 [1:11:21<00:00,  2.88it/s]\n",
      "Iteration:   0%|          | 1/12272 [00:00<1:12:29,  2.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:1->1228/12272; average training loss:0.319802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  10%|█         | 1229/12272 [07:09<1:02:29,  2.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:1229->2456/12272; average training loss:0.331876\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  20%|██        | 2457/12272 [14:15<55:22,  2.95it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:2457->3684/12272; average training loss:0.333463\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  30%|███       | 3685/12272 [21:21<48:41,  2.94it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:3685->4912/12272; average training loss:0.331817\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  40%|████      | 4913/12272 [28:25<41:26,  2.96it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:4913->6140/12272; average training loss:0.327940\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  50%|█████     | 6141/12272 [35:31<34:34,  2.96it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:6141->7368/12272; average training loss:0.325802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  60%|██████    | 7369/12272 [42:36<27:48,  2.94it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:7369->8596/12272; average training loss:0.324641\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  70%|███████   | 8597/12272 [49:42<20:53,  2.93it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:8597->9824/12272; average training loss:0.322036\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  80%|████████  | 9825/12272 [56:44<13:50,  2.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:9825->11052/12272; average training loss:0.321205\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration:  90%|█████████ | 11053/12272 [1:03:49<06:54,  2.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch:2/2; batch:11053->12272/12272; average training loss:0.319237\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration: 100%|██████████| 12272/12272 [1:10:52<00:00,  2.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training time : 2.374 hrs\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    classifier.fit(token_ids=train_token_ids,\n",
    "                   input_mask=train_input_mask,\n",
    "                   token_type_ids=train_token_type_ids,\n",
    "                   labels=train_labels,\n",
    "                   num_epochs=NUM_EPOCHS,\n",
    "                   batch_size=BATCH_SIZE,\n",
    "                   lr=LEARNING_RATE,\n",
    "                   warmup_proportion=WARMUP_PROPORTION)\n",
    "print(\"Training time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Predict on Development Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration: 100%|██████████| 307/307 [00:40<00:00,  8.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction time : 0.011 hrs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    predictions_matched = classifier.predict(token_ids=dev_matched_token_ids,\n",
    "                                             input_mask=dev_matched_input_mask,\n",
    "                                             token_type_ids=dev_matched_token_type_ids,\n",
    "                                             batch_size=BATCH_SIZE)\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Iteration: 100%|██████████| 308/308 [00:38<00:00,  8.30it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction time : 0.011 hrs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Timer() as t:\n",
    "    predictions_mismatched = classifier.predict(token_ids=dev_mismatched_token_ids,\n",
    "                                                input_mask=dev_mismatched_input_mask,\n",
    "                                                token_type_ids=dev_mismatched_token_type_ids,\n",
    "                                                batch_size=BATCH_SIZE)\n",
    "print(\"Prediction time : {:.3f} hrs\".format(t.interval / 3600))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.848     0.865     0.857      3213\n",
      "   entailment      0.894     0.828     0.860      3479\n",
      "      neutral      0.783     0.831     0.806      3123\n",
      "\n",
      "    micro avg      0.841     0.841     0.841      9815\n",
      "    macro avg      0.842     0.841     0.841      9815\n",
      " weighted avg      0.844     0.841     0.842      9815\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions_matched = label_encoder.inverse_transform(predictions_matched)\n",
    "print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "               precision    recall  f1-score   support\n",
      "\n",
      "contradiction      0.862     0.863     0.863      3240\n",
      "   entailment      0.878     0.853     0.865      3463\n",
      "      neutral      0.791     0.815     0.803      3129\n",
      "\n",
      "    micro avg      0.844     0.844     0.844      9832\n",
      "    macro avg      0.844     0.844     0.844      9832\n",
      " weighted avg      0.845     0.844     0.845      9832\n",
      "\n"
     ]
    }
   ],
   "source": [
    "predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)\n",
    "print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "nlp_gpu",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
